;; protocol.scm: base type declarations for Thrift protocol implementations
;; Copyright (C) 2012 Julian Graham

;; r6rs-thrift is free software: you can redistribute it and/or modify
;; it under the terms of the GNU General Public License as published by
;; the Free Software Foundation, either version 3 of the License, or
;; (at your option) any later version.

;; This program is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;; GNU General Public License for more details.

;; You should have received a copy of the GNU General Public License
;; along with this program.  If not, see <http://www.gnu.org/licenses/>.

#!r6rs

(library (thrift protocol)
  (export thrift:protocol
	  thrift:make-protocol
	  thrift:protocol?

	  thrift:protocol-deserializer
	  thrift:protocol-serializer

	  thrift:serializer
	  thrift:make-serializer
	  thrift:serializer?

	  thrift:serializer-write-message-begin
	  thrift:serializer-write-message-end
	  thrift:serializer-write-struct-begin
	  thrift:serializer-write-struct-end
	  thrift:serializer-write-field-begin
	  thrift:serializer-write-field-end
	  thrift:serializer-write-field-stop
	  thrift:serializer-write-map-begin
	  thrift:serializer-write-map-end
	  thrift:serializer-write-list-begin
	  thrift:serializer-write-list-end
	  thrift:serializer-write-set-begin
	  thrift:serializer-write-set-end
	  thrift:serializer-write-bool
	  thrift:serializer-write-byte
	  thrift:serializer-write-i16
	  thrift:serializer-write-i32
	  thrift:serializer-write-i64
	  thrift:serializer-write-double
	  thrift:serializer-write-string

	  thrift:deserializer
	  thrift:make-deserializer
	  thrift:deserializer?

	  thrift:deserializer-read-message-begin
	  thrift:deserializer-read-message-end
	  thrift:deserializer-read-struct-begin
	  thrift:deserializer-read-struct-end
	  thrift:deserializer-read-field-begin
	  thrift:deserializer-read-field-end
	  thrift:deserializer-read-map-begin
	  thrift:deserializer-read-map-end
	  thrift:deserializer-read-list-begin
	  thrift:deserializer-read-list-end
	  thrift:deserializer-read-set-begin
	  thrift:deserializer-read-set-end
	  thrift:deserializer-read-bool
	  thrift:deserializer-read-byte
	  thrift:deserializer-read-i16
	  thrift:deserializer-read-i32
	  thrift:deserializer-read-i64
	  thrift:deserializer-read-double
	  thrift:deserializer-read-string

	  thrift:byte->wire-type
	  thrift:wire-type->byte

	  thrift:byte->message-type
	  thrift:message-type->byte

	  thrift:send-message
	  thrift:receive-message)
	
  (import (rnrs)
	  (thrift private)
	  (thrift transport))

  ;; Encode a message-type symbol as its wire byte:
  ;; call = 1, reply = 2, exception = 3, oneway = 4.
  ;; Any other value raises an assertion violation.
  (define (thrift:message-type->byte message-type)
    ;; `case' compares against literal, unevaluated data, so the clauses
    ;; name the enumeration symbols directly.  (Writing
    ;; `(thrift:message-type call)' here would also match the symbol
    ;; `thrift:message-type' itself.)
    (case message-type
      ((call) 1)
      ((reply) 2)
      ((exception) 3)
      ((oneway) 4)
      ;; The raise must be applied; `(else raise ...)' merely returned
      ;; the condition object.
      (else (raise (make-assertion-violation)))))

  ;; Decode a message-type byte (1-4) into the matching thrift:message-type
  ;; value.  Bytes outside that set raise an assertion violation.
  (define (thrift:byte->message-type byte)
    (cond ((eqv? byte 1) (thrift:message-type call))
          ((eqv? byte 2) (thrift:message-type reply))
          ((eqv? byte 3) (thrift:message-type exception))
          ((eqv? byte 4) (thrift:message-type oneway))
          (else (raise (make-assertion-violation)))))

  ;; Encode a wire-type symbol as its type byte.  Unknown symbols raise an
  ;; assertion violation.  The byte codes are the inverse of
  ;; thrift:byte->wire-type.
  (define (thrift:wire-type->byte wire-type)
    ;; `case' datums are not evaluated, so the enumeration symbols are
    ;; listed directly.  The previous form `(thrift:wire-type stop)' also
    ;; matched the symbol `thrift:wire-type'.  Plain symbols match how
    ;; thrift:read-field-value dispatches on wire types.
    (case wire-type
      ((stop) 0)
      ((void) 1)
      ((bool) 2)
      ((byte) 3)
      ((double) 4)
      ((i16) 6)
      ((i32) 8)
      ((i64) 10)
      ((string) 11)
      ((struct) 12)
      ((map) 13)
      ((set) 14)
      ((list) 15)
      ((enum) 16)
      (else (raise (make-assertion-violation)))))

  ;; Decode a type byte into the corresponding thrift:wire-type value.
  ;; Lookup uses eqv?.  Codes with no entry (e.g. 5, 7, 9, or anything
  ;; above 16) raise an assertion violation.
  (define (thrift:byte->wire-type byte)
    (let ((entry
           (assv byte
                 (list (cons 0 (thrift:wire-type stop))
                       (cons 1 (thrift:wire-type void))
                       (cons 2 (thrift:wire-type bool))
                       (cons 3 (thrift:wire-type byte))
                       (cons 4 (thrift:wire-type double))
                       (cons 6 (thrift:wire-type i16))
                       (cons 8 (thrift:wire-type i32))
                       (cons 10 (thrift:wire-type i64))
                       (cons 11 (thrift:wire-type string))
                       (cons 12 (thrift:wire-type struct))
                       (cons 13 (thrift:wire-type map))
                       (cons 14 (thrift:wire-type set))
                       (cons 15 (thrift:wire-type list))
                       (cons 16 (thrift:wire-type enum))))))
      (if entry
          (cdr entry)
          (raise (make-assertion-violation)))))

  ;; A serializer is the table of write procedures a protocol provides.
  ;; Within this library each field's value is applied to an output port,
  ;; followed by the value being written (or the message / struct / field
  ;; / container being framed).  write-field-stop receives only the port.
  (define-record-type (thrift:serializer
                       thrift:make-serializer
                       thrift:serializer?)
    (fields (immutable write-message-begin)
            (immutable write-message-end)
            (immutable write-struct-begin)
            (immutable write-struct-end)
            (immutable write-field-begin)
            (immutable write-field-end)
            (immutable write-field-stop)
            (immutable write-map-begin)
            (immutable write-map-end)
            (immutable write-list-begin)
            (immutable write-list-end)
            (immutable write-set-begin)
            (immutable write-set-end)
            (immutable write-bool)
            (immutable write-byte)
            (immutable write-i16)
            (immutable write-i32)
            (immutable write-i64)
            (immutable write-double)
            (immutable write-string)))

  ;; A deserializer is the table of read procedures a protocol provides.
  ;; Within this library each field's value is applied to an input port.
  ;; Observed return shapes:
  ;;   read-message-begin  -> (values name type seq-id)
  ;;   read-list-begin     -> (values element-wire-type size)
  ;;   read-set-begin      -> (values element-wire-type size)
  ;;   read-map-begin      -> (values key-wire-type value-wire-type size)
  ;;   read-field-begin    -> multiple values, the first being the wire type
  (define-record-type (thrift:deserializer
                       thrift:make-deserializer
                       thrift:deserializer?)
    (fields (immutable read-message-begin)
            (immutable read-message-end)
            (immutable read-struct-begin)
            (immutable read-struct-end)
            (immutable read-field-begin)
            (immutable read-field-end)
            (immutable read-map-begin)
            (immutable read-map-end)
            (immutable read-list-begin)
            (immutable read-list-end)
            (immutable read-set-begin)
            (immutable read-set-end)
            (immutable read-bool)
            (immutable read-byte)
            (immutable read-i16)
            (immutable read-i32)
            (immutable read-i64)
            (immutable read-double)
            (immutable read-string)))

  ;; A protocol pairs a serializer (used for writing) with a deserializer
  ;; (used for reading).
  (define-record-type (thrift:protocol
                       thrift:make-protocol
                       thrift:protocol?)
    (fields (immutable serializer)
            (immutable deserializer)))

  ;; Write LIST, a Thrift list held as a vector, whose elements have type
  ;; TYPE.  The sequence is: list-begin, each element in index order via
  ;; thrift:write-value, then list-end.
  (define (thrift:write-list port list type protocol)
    (let ((serializer (thrift:protocol-serializer protocol)))
      ((thrift:serializer-write-list-begin serializer) port list)
      (let ((count (vector-length list)))
        (do ((i 0 (+ i 1)))
            ((= i count))
          (thrift:write-value port (vector-ref list i) type protocol)))
      ((thrift:serializer-write-list-end serializer) port list)))

  ;; Write MAP, a hashtable, as a Thrift map.  The sequence is: map-begin,
  ;; then for each entry the key (as KEY-TYPE) followed by its value (as
  ;; VALUE-TYPE), then map-end.  Entries are written in the order returned
  ;; by hashtable-entries.
  (define (thrift:write-map port map key-type value-type protocol)
    (let ((serializer (thrift:protocol-serializer protocol)))
      ((thrift:serializer-write-map-begin serializer) port map)
      (call-with-values
          (lambda () (hashtable-entries map))
        (lambda (keys vals)
          (vector-for-each
           (lambda (k v)
             (thrift:write-value port k key-type protocol)
             (thrift:write-value port v value-type protocol))
           keys vals)))
      ((thrift:serializer-write-map-end serializer) port map)))

  ;; Write SET, a Thrift set held as a Scheme list, whose elements have
  ;; type TYPE.  The sequence is: set-begin, each element in list order,
  ;; then set-end.
  (define (thrift:write-set port set type protocol)
    (let ((serializer (thrift:protocol-serializer protocol)))
      ((thrift:serializer-write-set-begin serializer) port set)
      (let next ((remaining set))
        (unless (null? remaining)
          (thrift:write-value port (car remaining) type protocol)
          (next (cdr remaining))))
      ((thrift:serializer-write-set-end serializer) port set)))

  ;; Write the enum VALUE as an i32 ordinal.  The value descriptor whose
  ;; name is eq? to VALUE is looked up in the enum type descriptor TYPE,
  ;; and its ordinal is written.  If no descriptor matches, an assertion
  ;; violation is raised.
  (define (thrift:write-enum port value type protocol)
    (let* ((write-i32 (thrift:serializer-write-i32
                       (thrift:protocol-serializer protocol)))
           (match (find (lambda (enum-value)
                          (eq? (thrift:enum-value-descriptor-name enum-value)
                               value))
                        (thrift:enum-field-type-descriptor-values type))))
      (unless match
        (raise (make-assertion-violation)))
      (write-i32 port (thrift:enum-value-descriptor-ordinal match))))

  ;; Write VALUE according to the resolved field type descriptor TYPE,
  ;; using PROTOCOL's serializer.  Dispatch is by kind of descriptor:
  ;;   - enums      -> thrift:write-enum
  ;;   - structs    -> thrift:write-exception or thrift:write-struct
  ;;   - containers -> thrift:write-list / thrift:write-map /
  ;;                   thrift:write-set, chosen by the name "list", "map"
  ;;                   or "set"
  ;;   - primitives -> the serializer's scalar writer, chosen by eq?
  ;;                   against the thrift:field-type-* constants
  ;; Any other TYPE raises an assertion violation.
  (define (thrift:write-value port value type protocol)
    (define serializer (thrift:protocol-serializer protocol))

    (cond ((thrift:enum-field-type-descriptor? type)
           (thrift:write-enum port value type protocol))

          ;; Establish that TYPE is a struct descriptor before asking
          ;; whether it is an exception.  NOTE(review): the exception?
          ;; procedure looks like a record accessor, which would reject
          ;; non-struct descriptors such as thrift:field-type-bool —
          ;; confirm in (thrift private).
          ((thrift:struct-field-type-descriptor? type)
           (if (thrift:struct-field-type-descriptor-exception? type)
               (thrift:write-exception port value protocol)
               (thrift:write-struct port value protocol)))

          ((thrift:parameterized-field-type-descriptor? type)
           (let ((params
                  (thrift:parameterized-field-type-descriptor-parameters type))
                 (name (thrift:field-type-descriptor-name type)))
             ;; `case' compares with eqv?, which is not a reliable test of
             ;; string equality, so the names are matched with equal?.
             ;; Assumes descriptor names are strings — TODO confirm.
             (cond ((equal? name "list")
                    (thrift:write-list port value (car params) protocol))
                   ((equal? name "map")
                    (thrift:write-map
                     port value (car params) (cadr params) protocol))
                   ((equal? name "set")
                    (thrift:write-set port value (car params) protocol))
                   (else (raise (make-assertion-violation))))))

          ((eq? type thrift:field-type-bool)
           ((thrift:serializer-write-bool serializer) port value))
          ((eq? type thrift:field-type-byte)
           ((thrift:serializer-write-byte serializer) port value))
          ((eq? type thrift:field-type-i16)
           ((thrift:serializer-write-i16 serializer) port value))
          ((eq? type thrift:field-type-i32)
           ((thrift:serializer-write-i32 serializer) port value))
          ((eq? type thrift:field-type-i64)
           ((thrift:serializer-write-i64 serializer) port value))
          ((eq? type thrift:field-type-double)
           ((thrift:serializer-write-double serializer) port value))
          ((eq? type thrift:field-type-string)
           ((thrift:serializer-write-string serializer) port value))
          (else (raise (make-assertion-violation)))))

  ;; Write the value of FIELD.  The field's declared type is taken from
  ;; its field descriptor and resolved with thrift:resolve-type before the
  ;; value is passed to thrift:write-value.  The field header is not
  ;; written here.
  (define (thrift:write-field port field protocol)
    (let ((descriptor (thrift:field-field-descriptor field)))
      (let ((resolved
             (thrift:resolve-type (thrift:field-descriptor-type descriptor))))
        (thrift:write-value port (thrift:field-value field) resolved protocol))))

  ;; Write STRUCT.  The sequence is:
  ;;   - struct-begin
  ;;   - for every field that has a value: field-begin, the field's value,
  ;;     field-end
  ;;   - the field-stop marker
  ;;   - struct-end
  ;; Fields without a value are omitted.
  (define (thrift:write-struct port struct protocol)
    (let* ((serializer (thrift:protocol-serializer protocol))
           (field-begin (thrift:serializer-write-field-begin serializer))
           (field-end (thrift:serializer-write-field-end serializer)))
      ((thrift:serializer-write-struct-begin serializer) port struct)
      (for-each (lambda (field)
                  (when (thrift:field-has-value? field)
                    (field-begin port field)
                    (thrift:write-field port field protocol)
                    (field-end port field)))
                (thrift:struct-fields struct))
      ((thrift:serializer-write-field-stop serializer) port)
      ((thrift:serializer-write-struct-end serializer) port struct)))

  ;; Write EXCEPTION in struct encoding.  Every field that has a value is
  ;; framed by field-begin/field-end, and the fields are terminated by
  ;; field-stop, all between struct-begin and struct-end.  Fields without a
  ;; value are omitted.
  (define (thrift:write-exception port exception protocol)
    (let ((serializer (thrift:protocol-serializer protocol)))
      ((thrift:serializer-write-struct-begin serializer) port exception)
      (let next ((fields (thrift:exception-fields exception)))
        (unless (null? fields)
          (let ((field (car fields)))
            (when (thrift:field-has-value? field)
              ((thrift:serializer-write-field-begin serializer) port field)
              (thrift:write-field port field protocol)
              ((thrift:serializer-write-field-end serializer) port field)))
          (next (cdr fields))))
      ((thrift:serializer-write-field-stop serializer) port)
      ((thrift:serializer-write-struct-end serializer) port exception)))

  ;; Read a Thrift list from PORT and return its elements as a vector, in
  ;; wire order.
  ;; TYPE is the element type reference passed through to
  ;; thrift:read-field-value, or #f when the element type is unknown.
  ;; read-list-end is called for its effect; the vector built here is
  ;; returned, as thrift:read-struct returns the value it builds.
  (define (thrift:read-list port protocol type)
    (define deserializer (thrift:protocol-deserializer protocol))

    (let-values (((wire-type size)
                  ((thrift:deserializer-read-list-begin deserializer) port)))
      (let ((v (make-vector size)))
        ;; Elements arrive sequentially, so fill indices 0 .. size-1 in
        ;; order.
        (do ((i 0 (+ i 1)))
            ((>= i size))
          (vector-set! v i (thrift:read-field-value
                            port protocol wire-type type)))
        ((thrift:deserializer-read-list-end deserializer) port v)
        v)))
  
  ;; Read a Thrift map from PORT into a fresh hashtable and return it.
  ;; The hashtable's hash and equivalence functions are derived from
  ;; KEY-TYPE.  KEY-TYPE and VALUE-TYPE are passed through to
  ;; thrift:read-field-value (#f when unknown).  read-map-end is called for
  ;; its effect; the hashtable built here is returned.
  (define (thrift:read-map port protocol key-type value-type)
    (define deserializer (thrift:protocol-deserializer protocol))

    (let-values (((key-wire-type value-wire-type size)
                  ((thrift:deserializer-read-map-begin deserializer) port)))
      (let ((ht (make-hashtable (thrift:type->hash-function key-type)
                                (thrift:type->equivalence-function key-type))))
        (do ((i 0 (+ i 1)))
            ((>= i size))
          ;; Bind the key before reading the value.  Scheme leaves operand
          ;; evaluation order unspecified, and both reads consume PORT.
          (let* ((k (thrift:read-field-value
                     port protocol key-wire-type key-type))
                 (v (thrift:read-field-value
                     port protocol value-wire-type value-type)))
            (hashtable-set! ht k v)))
        ((thrift:deserializer-read-map-end deserializer) port ht)
        ht)))

  ;; Read a Thrift set from PORT and return its elements as a list, in wire
  ;; order.  TYPE is the element type reference passed through to
  ;; thrift:read-field-value, or #f when unknown.  read-set-end receives the
  ;; port and the element list; the list is returned.
  (define (thrift:read-set port protocol type)
    (define deserializer (thrift:protocol-deserializer protocol))

    (let-values (((wire-type size)
                  ((thrift:deserializer-read-set-begin deserializer) port)))
      (let loop ((i size) (acc '()))
        (if (<= i 0)
            (let ((elements (reverse acc)))
              ((thrift:deserializer-read-set-end deserializer) port elements)
              elements)
            (loop (- i 1)
                  (cons (thrift:read-field-value port protocol wire-type type)
                        acc))))))

  ;; Read an i32 ordinal from PORT and return the name of the value
  ;; descriptor in the enum type descriptor TYPE whose ordinal is eqv? to
  ;; it.  If no descriptor matches, an assertion violation is raised.
  (define (thrift:read-enum port protocol type)
    (let* ((ordinal ((thrift:deserializer-read-i32
                      (thrift:protocol-deserializer protocol))
                     port))
           (match (find (lambda (enum-value)
                          (eqv? (thrift:enum-value-descriptor-ordinal enum-value)
                                ordinal))
                        (thrift:enum-field-type-descriptor-values type))))
      (if match
          (thrift:enum-value-descriptor-name match)
          (raise (make-assertion-violation)))))

  ;; Read one value whose wire type is WIRE-TYPE (a wire-type symbol) from
  ;; PORT.  The optional extra argument is a type reference.  When it is
  ;; present, it is resolved with thrift:resolve-type and used to:
  ;;   - build structs,
  ;;   - pick container element types,
  ;;   - map enum ordinals to names.
  ;; When it is absent (or #f), values are read untyped:
  ;;   - structs go to thrift:read-struct with a #f builder,
  ;;   - containers get #f element types,
  ;;   - enums are read as an i32 and #f is returned.
  ;; Unknown wire types raise an assertion violation.
  (define (thrift:read-field-value port protocol wire-type . type)
    (define type-reference (if (null? type) #f (car type)))
    (define deserializer (thrift:protocol-deserializer protocol))

    ;; The parameter list of the container type reference, or #f when the
    ;; type is unknown.
    (define (container-parameters)
      (and type-reference
           (thrift:parameterized-field-type-descriptor-parameters
            (thrift:resolve-type type-reference))))

    (case wire-type
      ((bool) ((thrift:deserializer-read-bool deserializer) port))
      ((byte) ((thrift:deserializer-read-byte deserializer) port))
      ((double) ((thrift:deserializer-read-double deserializer) port))
      ((i16) ((thrift:deserializer-read-i16 deserializer) port))
      ((i32) ((thrift:deserializer-read-i32 deserializer) port))
      ((i64) ((thrift:deserializer-read-i64 deserializer) port))
      ((string) ((thrift:deserializer-read-string deserializer) port))

      ((struct)
       (thrift:read-struct
        (and type-reference (thrift:make-struct-builder type-reference))
        port protocol))

      ;; The container readers return the decoded value directly; it must
      ;; not be applied as a procedure.
      ((list)
       (let ((params (container-parameters)))
         (thrift:read-list port protocol (and params (car params)))))

      ((map)
       (let ((params (container-parameters)))
         (if params
             (thrift:read-map port protocol (car params) (cadr params))
             (thrift:read-map port protocol #f #f))))

      ((set)
       (let ((params (container-parameters)))
         (thrift:read-set port protocol (and params (car params)))))

      ((enum)
       (if type-reference
           (thrift:read-enum
            port protocol (thrift:resolve-type type-reference))
           (begin ((thrift:deserializer-read-i32 deserializer) port) #f)))

      (else (raise (make-assertion-violation)))))

  ;; Read a struct from PORT.  Fields are read until read-field-begin
  ;; reports the STOP wire type.
  ;;
  ;; Each field header is collected into a list whose first element is the
  ;; wire type and whose second element is used as the field index.
  ;; Fields whose index matches one of BUILDER's fields are read with that
  ;; field's declared type and stored with thrift:set-field-value!.
  ;; Unmatched fields are read untyped and discarded.
  ;;
  ;; Returns the built struct.  BUILDER may be #f — thrift:read-field-value
  ;; passes #f when skipping a struct of unknown type.  In that case every
  ;; field is consumed and discarded, and #f is returned.
  (define (thrift:read-struct builder port protocol)
    (define deserializer (thrift:protocol-deserializer protocol))

    (define (read-type-index)
      (call-with-values
          (lambda ()
            ((thrift:deserializer-read-field-begin deserializer) port))
        list))

    (define field-table (make-eqv-hashtable))

    (when builder
      (for-each (lambda (field)
                  (hashtable-set! field-table
                                  (thrift:field-descriptor-index
                                   (thrift:field-field-descriptor field))
                                  field))
                (thrift:struct-builder-fields builder)))

    ((thrift:deserializer-read-struct-begin deserializer) port)
    (do ((type-index (read-type-index) (read-type-index)))
        ((eq? (car type-index) (thrift:wire-type stop)))
      (let ((field (hashtable-ref field-table (cadr type-index) #f)))
        (if field
            (thrift:set-field-value!
             field (thrift:read-field-value
                    port protocol (car type-index)
                    (thrift:field-descriptor-type
                     (thrift:field-field-descriptor field))))
            ;; Unknown field: consume its value so the stream stays
            ;; aligned.
            (thrift:read-field-value port protocol (car type-index)))
        ((thrift:deserializer-read-field-end deserializer) port field)))

    ((thrift:deserializer-read-struct-end deserializer) port builder)
    (and builder (thrift:struct-builder-build builder)))

  ;; Send a call of METHOD-NAME with ARGS over TRANSPORT.
  ;;
  ;; SEQUENCE is incremented first, and its new value becomes the
  ;; message's sequence id.  The arguments are written as the fields of an
  ;; anonymous struct, pairing each argument with the next argument field
  ;; descriptor of the function named METHOD-NAME in the service resolved
  ;; from TYPE-REFERENCE.  The sequence is: message-begin, struct-begin,
  ;; the fields, field-stop, struct-end, message-end.
  (define (thrift:send-message
           transport protocol sequence type-reference method-name . args)
    (define port (thrift:transport-output-port transport))
    (define serializer (thrift:protocol-serializer protocol))
    (define descriptor (thrift:resolve-service type-reference))

    ;; Find the descriptor of the function named METHOD-NAME (#f if none).
    (define (lookup-function)
      (find (lambda (function)
              (equal? (thrift:function-descriptor-name function) method-name))
            (thrift:service-descriptor-functions descriptor)))

    ;; Write ARG as a framed field described by FIELD-DESCRIPTOR.
    (define (write-argument! field-descriptor arg)
      (let ((field (thrift:make-field field-descriptor arg)))
        ((thrift:serializer-write-field-begin serializer) port field)
        (thrift:write-field port field protocol)
        ((thrift:serializer-write-field-end serializer) port field)))

    (thrift:increment-sequence! sequence)

    (let ((function-descriptor (lookup-function))
          (message (thrift:make-message
                    method-name
                    (thrift:message-type call)
                    (thrift:sequence-value sequence))))
      ((thrift:serializer-write-message-begin serializer) port message)
      ((thrift:serializer-write-struct-begin serializer) port #f)
      (do ((remaining args (cdr remaining))
           (descriptors
            (thrift:function-descriptor-arguments function-descriptor)
            (cdr descriptors)))
          ((null? remaining))
        (write-argument! (car descriptors) (car remaining)))
      ((thrift:serializer-write-field-stop serializer) port)
      ((thrift:serializer-write-struct-end serializer) port #f)
      ((thrift:serializer-write-message-end serializer) port message)))

  ;; Read the reply to a call of METHOD-NAME from TRANSPORT and return the
  ;; function's result.  For a void function the result is unspecified.
  ;;
  ;; The reply body is read as a result struct.  For non-void functions,
  ;; one field is read using the declared return type's wire type and
  ;; TYPE-REFERENCE's resolved return type.  The struct must then end with
  ;; the STOP marker.
  ;;
  ;; Raises an assertion violation when:
  ;;   - the reply is an exception message (its body is not decoded),
  ;;   - its sequence id differs from the current value of SEQUENCE,
  ;;   - the result struct holds anything besides the expected field.
  (define (thrift:receive-message
           transport protocol sequence type-reference method-name)
    (define port (thrift:transport-input-port transport))
    (define deserializer (thrift:protocol-deserializer protocol))
    (define descriptor (thrift:resolve-service type-reference))

    ;; Consume the STOP field header that terminates the result struct.
    ;; As in thrift:read-struct, STOP is returned by read-field-begin, with
    ;; the wire type as its first value.  read-struct-end does not consume
    ;; it, so leaving it unread would corrupt the next reply.
    (define (read-field-stop)
      (let ((header (call-with-values
                        (lambda ()
                          ((thrift:deserializer-read-field-begin deserializer)
                           port))
                      list)))
        (unless (eq? (car header) (thrift:wire-type stop))
          (raise (condition
                  (make-message-condition
                   "Expected field stop at end of result struct")
                  (make-assertion-violation))))))

    (let-values (((name type seq-id)
                  ((thrift:deserializer-read-message-begin deserializer)
                   port)))
      (let ((message (thrift:make-message name type seq-id)))
        (when (eq? (thrift:message-type-symbol message)
                   (thrift:message-type exception))
          ((thrift:deserializer-read-message-end deserializer) port message)
          (raise (make-assertion-violation)))

        (unless (eqv? (thrift:message-seq-id message)
                      (thrift:sequence-value sequence))
          (raise (condition
                  (make-message-condition
                   (string-append
                    "Message sequence id "
                    (number->string (thrift:message-seq-id message))
                    " does not match client sequence id "
                    (number->string (thrift:sequence-value sequence))))
                  (make-assertion-violation))))

        (let* ((function-descriptor
                (find (lambda (function)
                        (equal? (thrift:function-descriptor-name function)
                                method-name))
                      (thrift:service-descriptor-functions descriptor)))
               (return-type-reference
                (thrift:function-descriptor-return-type function-descriptor))
               (return-type (thrift:resolve-type return-type-reference)))
          ;; A reply always carries a result struct.  For void functions
          ;; it holds only the STOP marker.
          ((thrift:deserializer-read-struct-begin deserializer) port)
          (let ((return-value
                 (if (eq? return-type thrift:field-type-void)
                     (if #f #f)
                     (begin
                       ((thrift:deserializer-read-field-begin deserializer)
                        port)
                       (let ((val (thrift:read-field-value
                                   port protocol
                                   (thrift:field-type-descriptor-wire-type
                                    return-type)
                                   return-type-reference)))
                         ((thrift:deserializer-read-field-end deserializer)
                          port #f)
                         val)))))
            (read-field-stop)
            ((thrift:deserializer-read-struct-end deserializer) port #f)
            ((thrift:deserializer-read-message-end deserializer)
             port message)
            return-value)))))
)
